package com.googlecode.gaal.suffix.algorithm.impl;

import com.googlecode.gaal.data.api.IntSequence;
import com.googlecode.gaal.data.impl.ArraySequence;
import com.googlecode.gaal.suffix.algorithm.api.SuffixTableBuilder;

/**
 * Suffix table (suffix array) builder based on the skew / DC3 algorithm of
 * Kärkkäinen and Sanders.
 * <p>
 * The suffixes starting at positions {@code i % 3 != 0} are sorted first
 * (recursively, if their leading triples are not already distinct). The order
 * of the {@code i % 3 == 0} suffixes is derived from that result, and the two
 * sorted lists are merged.
 * 
 * @author Alex Kislev
 * 
 */
public class SkewSuffixTableBuilder implements SuffixTableBuilder {

    /**
     * Builds the suffix table of {@code text}.
     * <p>
     * Requirements carried over from the original implementation notes:
     * <ol>
     * <li>the text must be padded at the end with at least 3 elements;</li>
     * <li>the zeroth symbol of the symbol table cannot be used (0 acts as the
     * sentinel).</li>
     * </ol>
     * NOTE(review): whether {@code text.size()} is expected to include the
     * padding is not visible here; confirm against callers.
     * <p>
     * Texts of length 0 and 1 are handled directly. They yield an empty table
     * and {@code {0}} respectively.
     * 
     * @param text
     *            the input sequence; regular symbols are expected in
     *            {@code 1..alphabetSize}
     * @param alphabetSize
     *            the largest symbol value that may occur in {@code text}
     * @return an array of {@code text.size()} suffix start positions, in sorted
     *         suffix order
     */
    @Override
    public int[] buildSuffixTable(IntSequence text, int alphabetSize) {
        int n = text.size();
        // sort() writes every entry of the table, so no sentinel
        // pre-initialization is needed (and none is possible for n == 0).
        int[] suffixTable = new int[n];
        sort(text, suffixTable, n, alphabetSize);
        return suffixTable;
    }

    /**
     * Computes the suffix array {@code sa} of {@code s[0..n-1]}, whose symbols
     * lie in {@code 1..radix}. Reads at positions {@code n}, {@code n+1} and
     * {@code n+2} must yield 0.
     * <p>
     * All reads go through {@code s.get(index, 0)}. Presumably the second
     * argument is the value returned for indices past the end of the sequence;
     * TODO confirm against {@code IntSequence}.
     * 
     * @param s
     *            the sequence whose suffixes are sorted
     * @param sa
     *            receives the sorted suffix start positions in
     *            {@code sa[0..n-1]}
     * @param n
     *            the number of leading elements of {@code s} to sort
     * @param radix
     *            the largest symbol value (the alphabet size)
     */
    private void sort(IntSequence s, int[] sa, int n, int radix) {
        // The merge below needs at least one real mod 1/mod 2 suffix. For
        // n == 1 the only such suffix is the dummy one; it would be read past
        // the sorted range and emitted instead of suffix 0. Recursive calls
        // always have n >= 2, so this only affects the top-level call.
        if (n < 2) {
            if (n == 1) {
                sa[0] = 0;
            }
            return;
        }

        int n0 = (n + 2) / 3;
        int n1 = (n + 1) / 3;
        int n2 = n / 3;
        int n02 = n0 + n2;
        // Three extra slots stay 0 (Java zero-initializes arrays). They provide
        // the trailing padding required by the recursive call on s12.
        int[] s12 = new int[n02 + 3];
        int[] sa12 = new int[n02 + 3];
        int[] s0 = new int[n0];
        int[] sa0 = new int[n0];

        // Generate positions of the mod 1 and mod 2 suffixes. The "+ (n0 - n1)"
        // adds a dummy mod 1 suffix at position n when n % 3 == 1.
        for (int i = 0, j = 0; i < n + (n0 - n1); i++) {
            if (i % 3 != 0) {
                s12[j++] = i;
            }
        }

        // LSD radix sort of the mod 1 and mod 2 triples (last symbol first).
        radixPass(s12, sa12, s, 2, n02, radix);
        radixPass(sa12, s12, s, 1, n02, radix);
        radixPass(s12, sa12, s, 0, n02, radix);

        // Assign lexicographic names to the sorted triples; equal triples share
        // a name. Mod 1 names go to the left half of s12, mod 2 names to the
        // right half (offset n0).
        int name = 0, c0 = -1, c1 = -1, c2 = -1;
        for (int i = 0; i < n02; i++) {
            int pos = sa12[i];
            if (s.get(pos, 0) != c0 || s.get(pos + 1, 0) != c1 || s.get(pos + 2, 0) != c2) {
                name++;
                c0 = s.get(pos, 0);
                c1 = s.get(pos + 1, 0);
                c2 = s.get(pos + 2, 0);
            }
            if (pos % 3 == 1) {
                s12[pos / 3] = name; // left half
            } else {
                s12[pos / 3 + n0] = name; // right half
            }
        }

        if (name < n02) {
            // Names are not unique yet: recurse on the reduced string.
            sort(new ArraySequence(s12), sa12, n02, name);
            // Store unique 1-based ranks in s12 using the suffix array.
            for (int i = 0; i < n02; i++) {
                s12[sa12[i]] = i + 1;
            }
        } else {
            // Names are already unique: the suffix array of s12 follows
            // directly.
            for (int i = 0; i < n02; i++) {
                sa12[s12[i] - 1] = i;
            }
        }

        // Stably sort the mod 0 suffixes. Visit them in the order of their
        // successor mod 1 suffix (taken from sa12), then radix sort by first
        // symbol.
        for (int i = 0, j = 0; i < n02; i++) {
            if (sa12[i] < n0) {
                s0[j++] = 3 * sa12[i];
            }
        }
        radixPass(s0, sa0, s, 0, n0, radix);

        // Merge the sorted mod 0 suffixes (sa0) with the sorted mod 1/2
        // suffixes (sa12). t starts at n0 - n1 to skip the dummy suffix, which
        // sorts first because it consists only of zeros.
        for (int p = 0, t = n0 - n1, k = 0; k < n; k++) {
            int i = getI(sa12, n0, t); // position of the current mod 1/2 suffix
            int j = sa0[p]; // position of the current mod 0 suffix
            boolean sa12IsSmaller;
            if (sa12[t] < n0) {
                // i is a mod 1 suffix. Compare (s[i], rank of suffix i+1)
                // with (s[j], rank of suffix j+1).
                sa12IsSmaller = leq(s.get(i, 0), s12[sa12[t] + n0], s.get(j, 0), s12[j / 3]);
            } else {
                // i is a mod 2 suffix. Compare (s[i], s[i+1], rank of suffix
                // i+2) with (s[j], s[j+1], rank of suffix j+2).
                sa12IsSmaller = leq(s.get(i, 0), s.get(i + 1, 0), s12[sa12[t] - n0 + 1],
                        s.get(j, 0), s.get(j + 1, 0), s12[j / 3 + n0]);
            }
            if (sa12IsSmaller) {
                sa[k] = i;
                t++;
                if (t == n02) { // done: only sa0 suffixes are left
                    for (k++; p < n0; p++, k++) {
                        sa[k] = sa0[p];
                    }
                }
            } else {
                sa[k] = j;
                p++;
                if (p == n0) { // done: only sa12 suffixes are left
                    for (k++; t < n02; t++, k++) {
                        sa[k] = getI(sa12, n0, t);
                    }
                }
            }
        }
    }

    /**
     * Performs one pass of an LSD radix sort. Stably sorts {@code a[0..n-1]}
     * into {@code b[0..n-1]} using the key {@code keys[a[i] + digit]}. Keys must
     * lie in {@code 0..radix}.
     * 
     * @param a
     *            the positions to be sorted
     * @param b
     *            receives the sorted positions
     * @param keys
     *            the key source (the input text)
     * @param digit
     *            the offset added to each position to select its key
     * @param n
     *            the number of elements of {@code a} to sort
     * @param radix
     *            the largest key value
     */
    private static void radixPass(int[] a, int[] b, IntSequence keys, int digit, int n, int radix) {
        int[] counter = new int[radix + 1];

        // Count occurrences of each key.
        for (int i = 0; i < n; i++) {
            counter[keys.get(a[i] + digit, 0)]++;
        }

        // Exclusive prefix sums: counter[key] becomes the first output slot for
        // that key.
        for (int i = 0, sum = 0; i <= radix; i++) {
            int t = counter[i];
            counter[i] = sum;
            sum += t;
        }

        // Scatter in input order, which keeps the pass stable.
        for (int i = 0; i < n; i++) {
            b[counter[keys.get(a[i] + digit, 0)]++] = a[i];
        }
    }

    /**
     * Returns whether the pair (a1, a2) is lexicographically less than or equal
     * to (b1, b2).
     */
    private static boolean leq(int a1, int a2, int b1, int b2) {
        return (a1 < b1 || a1 == b1 && a2 <= b2);
    }

    /**
     * Returns whether the triple (a1, a2, a3) is lexicographically less than or
     * equal to (b1, b2, b3).
     */
    private static boolean leq(int a1, int a2, int a3, int b1, int b2, int b3) {
        return (a1 < b1 || a1 == b1 && leq(a2, a3, b2, b3));
    }

    /**
     * Maps entry {@code t} of {@code sa12} (an index into the two-halved s12
     * string) back to a position in the original text. Left-half indices map to
     * mod 1 positions and right-half indices to mod 2 positions.
     */
    private static int getI(int[] sa12, int n0, int t) {
        return (sa12[t] < n0 ? sa12[t] * 3 + 1 : (sa12[t] - n0) * 3 + 2);
    }
}
